{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# COMPSCI 389: Introduction to Machine Learning\n",
"# Generative AI\n",
"\n",
"Generative AI techniques are methods for generating new content like text, images, music, or other data, often mimicking some aspects of human creativity.\n",
"\n",
"Unlike supervised learning, which involves learning from *labeled* data, generative AI aims to learn the underlying patterns, features, and distributions of a dataset so that it can generate new, similar data. Hence, it is a form of *unsupervised* learning.\n",
"\n",
"Two core methods in generative AI are **variational autoencoders** (VAEs) and **generative adversarial networks** (GANs). In this notebook we review VAEs and GANs. We then provide an overview of how **large language models** (LLMs) are trained."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Variational Autoencoders (VAEs)\n",
"\n",
"VAEs are trained on a data set, such as a collection of images. They learn to create new rows (data points) that resemble those in the data set. They do this by converting this unsupervised learning problem into a supervised learning problem, and then applying the methods that we have discussed (gradient descent on a loss function for a parametric model).\n",
"\n",
"Specifically, VAEs take an input, $X_i$, and try to output $X_i$. That is, their loss function measures how close the output is to the input, $X_i$. Such a loss function is called a **reconstruction loss**. For image data, the *mean squared error* (MSE) between the pixel values of the original and reconstructed images is one common choice.\n",
"\n",
"This may seem simple: just define a parametric model that maps the input to the output! The key insight in VAEs is that the parametric model can be deliberately constructed to force the model to do more than this. The diagram below depicts a common *artificial neural network* (ANN) architecture for a VAE:\n",
"\n",
"\n",
"\n"
]
},
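{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the reconstruction loss concrete, here is a minimal sketch that computes the MSE between a batch of images and a stand-in reconstruction. The tensors `x` and `x_hat` are hypothetical random data (assumed to be 4 grayscale 28x28 images with pixel values in [0, 1]), not the output of a trained model; the sketch only checks that PyTorch's `F.mse_loss` matches the formula written out by hand, and that a perfect reconstruction has zero loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.nn import functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Hypothetical batch: 4 grayscale 28x28 images with pixel values in [0, 1]\n",
"x = torch.rand(4, 1, 28, 28)\n",
"# Stand-in reconstruction: the input plus a little noise, clipped back to [0, 1]\n",
"x_hat = (x + 0.05 * torch.randn_like(x)).clamp(0.0, 1.0)\n",
"\n",
"# MSE reconstruction loss: the average squared difference over every pixel\n",
"mse_manual = ((x_hat - x) ** 2).mean()\n",
"mse_torch = F.mse_loss(x_hat, x)  # reduction='mean' by default\n",
"\n",
"# Reconstructing the input exactly gives zero loss\n",
"mse_perfect = F.mse_loss(x, x)\n",
"\n",
"print(f'MSE (by hand):    {mse_manual.item():.6f}')\n",
"print(f'MSE (F.mse_loss): {mse_torch.item():.6f}')\n",
"print(f'MSE (perfect):    {mse_perfect.item():.6f}')"
]
},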
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The smaller boxes in the middle represent layers with fewer units. To reconstruct the input, $x$, with this architecture, the network must learn a smaller representation that encodes the content of the image. This smaller representation is called an **embedding** or a **latent representation**. The figure uses the term **latent space** for the space of all vectors that can be represented by the middle layer.\n",
"\n",
"As an example, consider the problem of reconstructing images of cats. The input might be a large, high-resolution RGB image. Even at a resolution of 1024x768 with three channels (R, G, and B), an image is represented by 1024 × 768 × 3 = 2,359,296 numbers. If the latent space is represented by a layer with 100 units, the network must learn to represent the entire image (2,359,296 numbers) with just 100 numbers! To do this, it might learn features like the breed of the cat (which largely determines its color), the age of the cat, the angle of the cat, whether the background is indoors or outdoors, etc.\n",
"\n",
"Intuitively, the \"Encoder\" part of the network (the layers before the small bottleneck layer that represents the latent space) will learn to map an image into the latent space, producing the smaller representation that describes the image. The \"Decoder\" part of the network (the layers after the bottleneck) will then learn to map this latent representation back to an image.\n",
"\n",
"Note that none of this behavior is hard-coded into the methods. The network is simply designed to have a small layer in the middle, and then trained to minimize the reconstruction loss. The network then creates an encoder and decoder on its own!\n",
"\n",
"Once an encoder and decoder have been trained, you can generate new images by providing random vectors as input to the decoder, essentially \"making up\" a representation in latent space and asking the decoder to reconstruct the corresponding image.\n",
"\n",
"However, without additional mechanisms, the autoencoder may learn to use only certain regions of the latent space, so new random vectors in the latent space may not map to meaningful images. **Variational** autoencoders add a term to the loss function that encourages the latent representations to follow a (standard) Gaussian distribution. This way, new outputs can be generated by sampling vectors of Gaussian noise and treating them as latent representations to be decoded. The details of this process are beyond the scope of this class, but note that VAEs are commonly trained by maximizing the **evidence lower bound** (ELBO), or equivalently by minimizing its negative as a loss. This loss balances A) ensuring that the distribution of latent representations that results from the training data is roughly Gaussian with B) the objective of reconstructing the input."
]
},
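{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the common case where the encoder outputs a Gaussian $\\mathcal{N}(\\mu_j, \\sigma_j^2)$ for each latent dimension $j$ and the target distribution is the standard normal $\\mathcal{N}(0, 1)$, the Gaussian part of the ELBO is a Kullback-Leibler (KL) divergence with a closed form:\n",
"\n",
"$$D_{KL} = -\\frac{1}{2} \\sum_j \\left(1 + \\log \\sigma_j^2 - \\mu_j^2 - \\sigma_j^2\\right).$$\n",
"\n",
"The minimal sketch below uses made-up values of `mu` and `logvar` (hypothetical encoder outputs for a 3-dimensional latent space) to check this formula, which is the `KLD` term in the `loss_function` of the example that follows, against `torch.distributions.kl_divergence`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.distributions import Normal, kl_divergence\n",
"\n",
"# Hypothetical encoder outputs for one image with a 3-dimensional latent space\n",
"mu = torch.tensor([0.5, -1.0, 0.0])\n",
"logvar = torch.tensor([0.0, -0.5, 1.0])  # log of the variance sigma^2\n",
"\n",
"# Closed-form KL divergence from N(mu, sigma^2) to N(0, 1), summed over latent dimensions\n",
"kl_closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
"\n",
"# The same quantity via torch.distributions (Normal takes the standard deviation, not the variance)\n",
"std = torch.exp(0.5 * logvar)\n",
"kl_library = kl_divergence(Normal(mu, std), Normal(torch.zeros(3), torch.ones(3))).sum()\n",
"\n",
"# The KL term is zero when mu = 0 and logvar = 0 (sigma = 1), i.e., when the code is already standard normal\n",
"zeros = torch.zeros(3)\n",
"kl_at_prior = -0.5 * torch.sum(1 + zeros - zeros.pow(2) - zeros.exp())\n",
"\n",
"print(f'closed form: {kl_closed_form.item():.4f}')\n",
"print(f'library:     {kl_library.item():.4f}')\n",
"print(f'at N(0, 1):  {kl_at_prior.item():.4f}')"
]
},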
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example VAE\n",
"\n",
"Although the details are beyond the scope of this course, notice that training a VAE in PyTorch is not significantly different from the classification and regression examples we have seen. The main difference is that the forward pass doesn't just output the attempted reconstruction of the image: it also outputs the mean (`mu`) and log-variance (`logvar`) of the Gaussian that describes the image's latent representation. The loss function needs these to push the network toward latent representations of the input data that are Gaussian."
]
},
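{
"cell_type": "markdown",
"metadata": {},
"source": [
"The extra outputs, `mu` and `logvar`, describe a Gaussian from which the latent code is sampled. Sampling is random, so to let gradients reach the encoder, the model below uses the *reparameterization trick* (see its `reparameterize` method): draw $\\epsilon \\sim \\mathcal{N}(0, 1)$ and set $z = \\mu + \\sigma \\epsilon$, where $\\sigma = e^{\\text{logvar}/2}$. In this minimal sketch, `mu` and `logvar` are made-up values for a 2-dimensional latent space, and the averaged `z` is an arbitrary stand-in for a real loss; the point is only that gradients flow back to both parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Hypothetical encoder outputs for a 2-dimensional latent space\n",
"mu = torch.tensor([1.0, -2.0], requires_grad=True)\n",
"logvar = torch.tensor([0.0, 0.0], requires_grad=True)  # variance exp(0) = 1\n",
"\n",
"# Reparameterization trick: all randomness is in eps, so z is a differentiable function of mu and logvar\n",
"std = torch.exp(0.5 * logvar)\n",
"eps = torch.randn(10000, 2)  # 10,000 draws of N(0, 1) noise\n",
"z = mu + eps * std           # each row is one sample of z ~ N(mu, sigma^2)\n",
"\n",
"# Arbitrary stand-in loss, used only to show that gradients reach mu and logvar\n",
"stand_in_loss = z.mean(dim=0).sum()\n",
"stand_in_loss.backward()\n",
"\n",
"print('sample mean of z:', z.mean(dim=0).detach())  # close to mu\n",
"print('gradient w.r.t. mu:', mu.grad)               # 1 per dimension (up to rounding)\n",
"print('gradient w.r.t. logvar:', logvar.grad)"
]
},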
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"from torch import nn, optim\n",
"from torchvision import transforms\n",
"from torch.nn import functional as F\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# VAE model\n",
"class VAE(nn.Module):\n",
" def __init__(self):\n",
" super(VAE, self).__init__() # Call the nn.Module initializer\n",
" self.fc1 = nn.Linear(784, 400)\n",
" self.fc21 = nn.Linear(400, 20) # mu layer (mean of the Gaussian)\n",
" self.fc22 = nn.Linear(400, 20) # logvar layer (log of the variance)\n",
" self.fc3 = nn.Linear(20, 400)\n",
" self.fc4 = nn.Linear(400, 784)\n",
"\n",
" def encode(self, x):\n",
" h1 = F.relu(self.fc1(x))\n",
" return self.fc21(h1), self.fc22(h1)\n",
"\n",
" def reparameterize(self, mu, logvar):\n",
" std = torch.exp(0.5*logvar)\n",
" eps = torch.randn_like(std) # The length is that of \"std\", but these are samples of N(0,1)\n",
" return mu + eps*std\n",
"\n",
" def decode(self, z):\n",
" h3 = F.relu(self.fc3(z))\n",
" return torch.sigmoid(self.fc4(h3))\n",
"\n",
" def forward(self, x):\n",
" mu, logvar = self.encode(x.view(-1, 784))\n",
" z = self.reparameterize(mu, logvar)\n",
" return self.decode(z), mu, logvar"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
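"# Loss function: the negative ELBO, minimized by gradient descent\n",
"def loss_function(recon_x, x, mu, logvar):\n",
"    # Reconstruction term: binary cross-entropy between reconstructed and original pixels (both in [0, 1])\n",
"    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')\n",
"    # KL divergence from N(mu, sigma^2) to N(0, 1): pushes the latent codes toward a standard Gaussian\n",
"    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
"    return BCE + KLD"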
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 404: Not Found\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\\MNIST\\raw\\train-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 9912422/9912422 [00:00<00:00, 51984454.62it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./data\\MNIST\\raw\\train-images-idx3-ubyte.gz to ./data\\MNIST\\raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 404: Not Found\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\\MNIST\\raw\\train-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 28881/28881 [00:00<00:00, 1806593.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./data\\MNIST\\raw\\train-labels-idx1-ubyte.gz to ./data\\MNIST\\raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 404: Not Found\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1648877/1648877 [00:00<00:00, 16074085.58it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./data\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to ./data\\MNIST\\raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Failed to download (trying next):\n",
"HTTP Error 404: Not Found\n",
"\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n",
"Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 4542/4542 [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./data\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to ./data\\MNIST\\raw\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Data loading\n",
"transform = transforms.Compose([transforms.ToTensor()])\n",
"trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Model and optimizer\n",
"model = VAE()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Loss: 163.6830\n",
"Epoch 2, Loss: 121.0160\n",
"Epoch 3, Loss: 114.4315\n",
"Epoch 4, Loss: 111.5428\n",
"Epoch 5, Loss: 109.8738\n",
"Epoch 6, Loss: 108.6868\n",
"Epoch 7, Loss: 107.8526\n",
"Epoch 8, Loss: 107.2281\n",
"Epoch 9, Loss: 106.6957\n",
"Epoch 10, Loss: 106.3061\n"
]
}
],
"source": [
"# Training\n",
"for epoch in range(1, 11):\n",
"    train_loss = 0\n",
"    for data, _ in trainloader:\n",
"        recon_batch, mu, logvar = model(data)  # Calling the module runs forward()\n",
"        optimizer.zero_grad()\n",
"        loss = loss_function(recon_batch, data, mu, logvar)\n",
"        loss.backward()\n",
"        train_loss += loss.item()\n",
"        optimizer.step()\n",
"    # Average loss per training image (the loss is summed over each batch)\n",
"    print(f'Epoch {epoch}, Loss: {train_loss / len(trainloader.dataset):.4f}')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Generating images\n",
"def show_generated_images(model, num_images=10):\n",
" with torch.no_grad():\n",
" z = torch.randn(num_images, 20)\n",
" sample = model.decode(z).cpu()\n",
" sample = sample.view(num_images, 28, 28)\n",
"\n",
" fig, axs = plt.subplots(1, num_images, figsize=(num_images, 1))\n",
" for i in range(num_images):\n",
" axs[i].imshow(sample[i].numpy(), cmap='gray')\n",
" axs[i].axis('off')\n",
" plt.show()\n",
"\n",
"show_generated_images(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After just 10 epochs (1.5 minutes on my CPU) of training, we're starting to see images that look like handwritten digits! With more training these could likely be further improved."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generative Adversarial Networks (GANs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"GANs use two neural networks that learn from each other by competing in a kind of tug-of-war. These two networks are called the **generator** and the **discriminator**.\n",
"\n",
"#### Generator\n",
"\n",
"The generator is tasked with creating fake data points (e.g., fake images). Its goal is to create images that are indistinguishable from the real ones in the data set. It starts by creating the image from random noise. That is, it takes random noise as input, and produces an image as output.\n",
"\n",
"#### Discriminator\n",
"\n",
"The discriminator takes both the fake images from the generator and the real images from the training data set as its input data (one image at a time). Its job is to learn to differentiate between the two, discerning real from fake: for each image, it outputs the probability that the image is real.\n",
"\n",
"#### Training\n",
"\n",
"The loss function used by the generator penalizes it for producing images that the discriminator is able to detect as fakes. The loss function for the discriminator penalizes it for misclassifying images: labeling the generator's fakes as real, or labeling real images as fake.\n",
"\n",
"During the early stages of training, the generator's images are easily spotted as fakes. However, as training progresses, the generator learns what makes an image more believable. It can learn complex representations: not only specific features like the color or shape of an object, but also deeper aspects, like the way light and shadow play on surfaces or how different objects in an image relate to each other in space.\n",
"\n",
"What's truly remarkable is that this entire process evolves without explicit instructions on what to learn. The networks, through their adversarial training, figure out the details. Eventually, if the training is successful, the generator becomes so skilled that the discriminator can't tell its creations from real images.\n"
]
},
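{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two loss functions described above can both be written as binary cross-entropy (BCE) on the discriminator's outputs, which are probabilities that an image is real. In this minimal sketch, `d_real` and `d_fake` are made-up discriminator outputs rather than values from a trained model. The generator loss shown is the common non-saturating form $-\\log D(G(z))$, averaged over fakes, which is also how the example below trains its generator (by labeling fake images as real). With these made-up scores the discriminator is winning, so the generator's loss is the larger of the two."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"# Hypothetical discriminator outputs: the probability that each image is real\n",
"d_real = torch.tensor([[0.9], [0.8], [0.7]])  # scores on three real images\n",
"d_fake = torch.tensor([[0.2], [0.1], [0.3]])  # scores on three generated images\n",
"\n",
"criterion = nn.BCELoss()\n",
"ones = torch.ones_like(d_real)\n",
"zeros = torch.zeros_like(d_fake)\n",
"\n",
"# Discriminator loss: penalized for calling real images fake and for calling fakes real\n",
"d_loss = criterion(d_real, ones) + criterion(d_fake, zeros)\n",
"\n",
"# Generator loss: penalized when the discriminator spots its fakes (fakes are labeled as real)\n",
"g_loss = criterion(d_fake, ones)\n",
"\n",
"# The same losses written out by hand\n",
"d_loss_manual = -(torch.log(d_real).mean() + torch.log(1 - d_fake).mean())\n",
"g_loss_manual = -torch.log(d_fake).mean()\n",
"\n",
"print(f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')"
]
},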
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example GAN"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [0/15], Step [100/938], d_loss: 0.1968, g_loss: 2.1253\n",
"Epoch [0/15], Step [200/938], d_loss: 0.1253, g_loss: 3.2851\n",
"Epoch [0/15], Step [300/938], d_loss: 0.1698, g_loss: 4.1563\n",
"Epoch [0/15], Step [400/938], d_loss: 0.0610, g_loss: 5.5070\n",
"Epoch [0/15], Step [500/938], d_loss: 0.0136, g_loss: 6.1925\n",
"Epoch [0/15], Step [600/938], d_loss: 0.0244, g_loss: 6.1424\n",
"Epoch [0/15], Step [700/938], d_loss: 0.0265, g_loss: 6.5670\n",
"Epoch [0/15], Step [800/938], d_loss: 0.0337, g_loss: 6.3879\n",
"Epoch [0/15], Step [900/938], d_loss: 0.0803, g_loss: 7.4044\n",
"Epoch [1/15], Step [100/938], d_loss: 0.1964, g_loss: 11.1679\n",
"Epoch [1/15], Step [200/938], d_loss: 0.1964, g_loss: 6.7413\n",
"Epoch [1/15], Step [300/938], d_loss: 0.1970, g_loss: 3.5620\n",
"Epoch [1/15], Step [400/938], d_loss: 0.2756, g_loss: 5.0921\n",
"Epoch [1/15], Step [500/938], d_loss: 0.3247, g_loss: 4.7702\n",
"Epoch [1/15], Step [600/938], d_loss: 0.3061, g_loss: 4.3312\n",
"Epoch [1/15], Step [700/938], d_loss: 0.2647, g_loss: 6.2605\n",
"Epoch [1/15], Step [800/938], d_loss: 0.3247, g_loss: 3.7898\n",
"Epoch [1/15], Step [900/938], d_loss: 0.5217, g_loss: 3.3301\n",
"Epoch [2/15], Step [100/938], d_loss: 0.2458, g_loss: 2.7131\n",
"Epoch [2/15], Step [200/938], d_loss: 0.4238, g_loss: 4.3987\n",
"Epoch [2/15], Step [300/938], d_loss: 0.5795, g_loss: 2.4602\n",
"Epoch [2/15], Step [400/938], d_loss: 0.5659, g_loss: 3.4101\n",
"Epoch [2/15], Step [500/938], d_loss: 0.4948, g_loss: 3.2021\n",
"Epoch [2/15], Step [600/938], d_loss: 0.4680, g_loss: 3.7562\n",
"Epoch [2/15], Step [700/938], d_loss: 0.6839, g_loss: 2.6689\n",
"Epoch [2/15], Step [800/938], d_loss: 0.8165, g_loss: 3.4143\n",
"Epoch [2/15], Step [900/938], d_loss: 0.4250, g_loss: 2.2062\n",
"Epoch [3/15], Step [100/938], d_loss: 0.5896, g_loss: 2.6581\n",
"Epoch [3/15], Step [200/938], d_loss: 0.4780, g_loss: 3.5904\n",
"Epoch [3/15], Step [300/938], d_loss: 0.6242, g_loss: 2.2625\n",
"Epoch [3/15], Step [400/938], d_loss: 0.6681, g_loss: 2.8004\n",
"Epoch [3/15], Step [500/938], d_loss: 0.7445, g_loss: 3.0726\n",
"Epoch [3/15], Step [600/938], d_loss: 0.5506, g_loss: 3.5891\n",
"Epoch [3/15], Step [700/938], d_loss: 0.7191, g_loss: 2.5736\n",
"Epoch [3/15], Step [800/938], d_loss: 0.7074, g_loss: 2.7978\n",
"Epoch [3/15], Step [900/938], d_loss: 0.5882, g_loss: 1.9058\n",
"Epoch [4/15], Step [100/938], d_loss: 0.9155, g_loss: 1.6220\n",
"Epoch [4/15], Step [200/938], d_loss: 0.6090, g_loss: 2.3546\n",
"Epoch [4/15], Step [300/938], d_loss: 0.7083, g_loss: 2.6362\n",
"Epoch [4/15], Step [400/938], d_loss: 0.8851, g_loss: 1.7231\n",
"Epoch [4/15], Step [500/938], d_loss: 0.7574, g_loss: 1.7395\n",
"Epoch [4/15], Step [600/938], d_loss: 0.8677, g_loss: 2.0268\n",
"Epoch [4/15], Step [700/938], d_loss: 0.6598, g_loss: 1.8621\n",
"Epoch [4/15], Step [800/938], d_loss: 0.6667, g_loss: 2.3757\n",
"Epoch [4/15], Step [900/938], d_loss: 0.5857, g_loss: 1.9318\n",
"Epoch [5/15], Step [100/938], d_loss: 0.6102, g_loss: 2.1001\n",
"Epoch [5/15], Step [200/938], d_loss: 0.7180, g_loss: 2.3402\n",
"Epoch [5/15], Step [300/938], d_loss: 0.6467, g_loss: 2.7362\n",
"Epoch [5/15], Step [400/938], d_loss: 0.8340, g_loss: 1.7500\n",
"Epoch [5/15], Step [500/938], d_loss: 0.6634, g_loss: 1.8997\n",
"Epoch [5/15], Step [600/938], d_loss: 0.8033, g_loss: 2.6520\n",
"Epoch [5/15], Step [700/938], d_loss: 0.9429, g_loss: 1.7822\n",
"Epoch [5/15], Step [800/938], d_loss: 0.8467, g_loss: 2.0240\n",
"Epoch [5/15], Step [900/938], d_loss: 0.6981, g_loss: 2.1250\n",
"Epoch [6/15], Step [100/938], d_loss: 0.9844, g_loss: 1.8475\n",
"Epoch [6/15], Step [200/938], d_loss: 0.7247, g_loss: 2.3818\n",
"Epoch [6/15], Step [300/938], d_loss: 0.7758, g_loss: 1.2705\n",
"Epoch [6/15], Step [400/938], d_loss: 0.6998, g_loss: 2.2222\n",
"Epoch [6/15], Step [500/938], d_loss: 0.9477, g_loss: 1.8426\n",
"Epoch [6/15], Step [600/938], d_loss: 0.6771, g_loss: 1.9632\n",
"Epoch [6/15], Step [700/938], d_loss: 0.9380, g_loss: 1.7148\n",
"Epoch [6/15], Step [800/938], d_loss: 0.6734, g_loss: 1.9416\n",
"Epoch [6/15], Step [900/938], d_loss: 0.8901, g_loss: 1.5339\n",
"Epoch [7/15], Step [100/938], d_loss: 0.9863, g_loss: 1.7283\n",
"Epoch [7/15], Step [200/938], d_loss: 0.7632, g_loss: 2.4689\n",
"Epoch [7/15], Step [300/938], d_loss: 0.9671, g_loss: 1.7843\n",
"Epoch [7/15], Step [400/938], d_loss: 0.7342, g_loss: 1.7388\n",
"Epoch [7/15], Step [500/938], d_loss: 0.7367, g_loss: 1.4516\n",
"Epoch [7/15], Step [600/938], d_loss: 0.9533, g_loss: 1.7038\n",
"Epoch [7/15], Step [700/938], d_loss: 0.9913, g_loss: 2.5165\n",
"Epoch [7/15], Step [800/938], d_loss: 0.6771, g_loss: 1.8887\n",
"Epoch [7/15], Step [900/938], d_loss: 0.8115, g_loss: 2.2673\n",
"Epoch [8/15], Step [100/938], d_loss: 0.8134, g_loss: 2.1915\n",
"Epoch [8/15], Step [200/938], d_loss: 0.7089, g_loss: 2.4003\n",
"Epoch [8/15], Step [300/938], d_loss: 0.9474, g_loss: 1.8414\n",
"Epoch [8/15], Step [400/938], d_loss: 0.9899, g_loss: 1.6218\n",
"Epoch [8/15], Step [500/938], d_loss: 0.8158, g_loss: 1.9243\n",
"Epoch [8/15], Step [600/938], d_loss: 0.7556, g_loss: 1.8116\n",
"Epoch [8/15], Step [700/938], d_loss: 0.7415, g_loss: 2.0091\n",
"Epoch [8/15], Step [800/938], d_loss: 0.6599, g_loss: 2.2918\n",
"Epoch [8/15], Step [900/938], d_loss: 0.8394, g_loss: 2.2152\n",
"Epoch [9/15], Step [100/938], d_loss: 0.8547, g_loss: 2.2970\n",
"Epoch [9/15], Step [200/938], d_loss: 0.9285, g_loss: 2.1147\n",
"Epoch [9/15], Step [300/938], d_loss: 0.9043, g_loss: 2.4567\n",
"Epoch [9/15], Step [400/938], d_loss: 1.0239, g_loss: 1.4896\n",
"Epoch [9/15], Step [500/938], d_loss: 0.7885, g_loss: 1.6315\n",
"Epoch [9/15], Step [600/938], d_loss: 0.6050, g_loss: 1.9700\n",
"Epoch [9/15], Step [700/938], d_loss: 0.7771, g_loss: 1.6589\n",
"Epoch [9/15], Step [800/938], d_loss: 0.7726, g_loss: 1.6334\n",
"Epoch [9/15], Step [900/938], d_loss: 0.9044, g_loss: 2.0588\n",
"Epoch [10/15], Step [100/938], d_loss: 0.9735, g_loss: 2.0161\n",
"Epoch [10/15], Step [200/938], d_loss: 0.7590, g_loss: 2.6000\n",
"Epoch [10/15], Step [300/938], d_loss: 0.7545, g_loss: 1.9695\n",
"Epoch [10/15], Step [400/938], d_loss: 0.6759, g_loss: 2.0362\n",
"Epoch [10/15], Step [500/938], d_loss: 0.6649, g_loss: 1.5897\n",
"Epoch [10/15], Step [600/938], d_loss: 0.8588, g_loss: 1.8251\n",
"Epoch [10/15], Step [700/938], d_loss: 0.6679, g_loss: 2.3880\n",
"Epoch [10/15], Step [800/938], d_loss: 0.7888, g_loss: 1.7575\n",
"Epoch [10/15], Step [900/938], d_loss: 0.9048, g_loss: 1.1682\n",
"Epoch [11/15], Step [100/938], d_loss: 0.8266, g_loss: 1.4023\n",
"Epoch [11/15], Step [200/938], d_loss: 0.8397, g_loss: 1.3329\n",
"Epoch [11/15], Step [300/938], d_loss: 0.7491, g_loss: 1.8133\n",
"Epoch [11/15], Step [400/938], d_loss: 0.9517, g_loss: 1.9995\n",
"Epoch [11/15], Step [500/938], d_loss: 0.8492, g_loss: 1.2039\n",
"Epoch [11/15], Step [600/938], d_loss: 0.9957, g_loss: 1.6000\n",
"Epoch [11/15], Step [700/938], d_loss: 0.6942, g_loss: 1.1894\n",
"Epoch [11/15], Step [800/938], d_loss: 0.7000, g_loss: 2.2242\n",
"Epoch [11/15], Step [900/938], d_loss: 0.7945, g_loss: 1.6052\n",
"Epoch [12/15], Step [100/938], d_loss: 0.8775, g_loss: 1.5036\n",
"Epoch [12/15], Step [200/938], d_loss: 1.1906, g_loss: 1.6118\n",
"Epoch [12/15], Step [300/938], d_loss: 1.0698, g_loss: 1.4668\n",
"Epoch [12/15], Step [400/938], d_loss: 1.1755, g_loss: 1.7004\n",
"Epoch [12/15], Step [500/938], d_loss: 0.8775, g_loss: 1.9998\n",
"Epoch [12/15], Step [600/938], d_loss: 0.9339, g_loss: 2.0696\n",
"Epoch [12/15], Step [700/938], d_loss: 0.6873, g_loss: 1.9389\n",
"Epoch [12/15], Step [800/938], d_loss: 0.8550, g_loss: 1.9689\n",
"Epoch [12/15], Step [900/938], d_loss: 0.6931, g_loss: 2.1201\n",
"Epoch [13/15], Step [100/938], d_loss: 0.6622, g_loss: 2.2526\n",
"Epoch [13/15], Step [200/938], d_loss: 0.6650, g_loss: 2.4257\n",
"Epoch [13/15], Step [300/938], d_loss: 0.9299, g_loss: 2.2868\n",
"Epoch [13/15], Step [400/938], d_loss: 0.9252, g_loss: 1.6538\n",
"Epoch [13/15], Step [500/938], d_loss: 0.9582, g_loss: 1.5898\n",
"Epoch [13/15], Step [600/938], d_loss: 0.7105, g_loss: 2.0934\n",
"Epoch [13/15], Step [700/938], d_loss: 0.6426, g_loss: 2.4876\n",
"Epoch [13/15], Step [800/938], d_loss: 0.8371, g_loss: 1.9572\n",
"Epoch [13/15], Step [900/938], d_loss: 0.9026, g_loss: 1.5870\n",
"Epoch [14/15], Step [100/938], d_loss: 0.8441, g_loss: 1.7409\n",
"Epoch [14/15], Step [200/938], d_loss: 1.1538, g_loss: 2.4822\n",
"Epoch [14/15], Step [300/938], d_loss: 0.8320, g_loss: 2.2448\n",
"Epoch [14/15], Step [400/938], d_loss: 1.0065, g_loss: 1.7845\n",
"Epoch [14/15], Step [500/938], d_loss: 1.2467, g_loss: 1.5347\n",
"Epoch [14/15], Step [600/938], d_loss: 0.9260, g_loss: 1.8226\n",
"Epoch [14/15], Step [700/938], d_loss: 0.8086, g_loss: 1.5104\n",
"Epoch [14/15], Step [800/938], d_loss: 0.9717, g_loss: 1.4188\n",
"Epoch [14/15], Step [900/938], d_loss: 0.9018, g_loss: 2.4415\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.autograd import Variable\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Hyperparameters\n",
"batch_size = 64\n",
"learning_rate = 0.0002\n",
"epochs = 15\n",
"\n",
"# MNIST Dataset\n",
"transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=(0.5,), std=(0.5,))\n",
"])\n",
"\n",
"train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)\n",
"train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
"\n",
"# Discriminator\n",
"class Discriminator(nn.Module):\n",
" def __init__(self):\n",
" super(Discriminator, self).__init__()\n",
" self.fc = nn.Sequential(\n",
" nn.Linear(784, 256),\n",
" nn.LeakyReLU(0.2),\n",
" nn.Linear(256, 256),\n",
" nn.LeakyReLU(0.2),\n",
" nn.Linear(256, 1),\n",
" nn.Sigmoid()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = x.view(x.size(0), -1)\n",
" return self.fc(x)\n",
"\n",
"# Generator\n",
"class Generator(nn.Module):\n",
" def __init__(self):\n",
" super(Generator, self).__init__()\n",
" self.fc = nn.Sequential(\n",
" nn.Linear(100, 256),\n",
" nn.LeakyReLU(0.2),\n",
" nn.BatchNorm1d(256),\n",
" nn.Linear(256, 256),\n",
" nn.LeakyReLU(0.2),\n",
" nn.BatchNorm1d(256),\n",
" nn.Linear(256, 784),\n",
" nn.Tanh()\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.fc(x)\n",
"\n",
"discriminator = Discriminator()\n",
"generator = Generator()\n",
"\n",
"# Loss and Optimizer\n",
"criterion = nn.BCELoss()\n",
"d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)\n",
"g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)\n",
"\n",
"# Training\n",
"for epoch in range(epochs):\n",
"    for i, (images, _) in enumerate(train_loader):\n",
"        current_batch_size = images.size(0)\n",
"\n",
"        # Train Discriminator (plain tensors work directly; Variable wrappers are not needed)\n",
"        real_images = images.view(current_batch_size, -1)\n",
"        real_labels = torch.ones(current_batch_size, 1)\n",
"        fake_labels = torch.zeros(current_batch_size, 1)\n",
"\n",
"        # Real images loss\n",
"        outputs = discriminator(real_images)\n",
"        d_loss_real = criterion(outputs, real_labels)\n",
"\n",
"        # Fake images loss (detach so this step does not backpropagate into the generator)\n",
"        z = torch.randn(current_batch_size, 100)\n",
"        fake_images = generator(z)\n",
"        outputs = discriminator(fake_images.detach())\n",
"        d_loss_fake = criterion(outputs, fake_labels)\n",
"\n",
"        # Backprop and optimize\n",
"        d_loss = d_loss_real + d_loss_fake\n",
"        d_optimizer.zero_grad()\n",
"        d_loss.backward()\n",
"        d_optimizer.step()\n",
"\n",
"        # Train Generator: its fakes are labeled as real, so it is penalized when the discriminator spots them\n",
"        z = torch.randn(current_batch_size, 100)\n",
"        fake_images = generator(z)\n",
"        outputs = discriminator(fake_images)\n",
"        g_loss = criterion(outputs, real_labels)\n",
"\n",
"        # Backprop and optimize (g_loss.backward() also fills the discriminator's gradients, so clear both)\n",
"        d_optimizer.zero_grad()\n",
"        g_optimizer.zero_grad()\n",
"        g_loss.backward()\n",
"        g_optimizer.step()\n",
"\n",
"        if (i + 1) % 100 == 0:\n",
"            print(f'Epoch [{epoch}/{epochs}], Step [{i+1}/{len(train_loader)}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')\n",
"\n",
"# Generate and show images\n",
"def show_generated_images(generator, num_images=10):\n",
" z = torch.randn(num_images, 100)\n",
" fake_images = generator(z)\n",
" fake_images = fake_images.view(fake_images.size(0), 28, 28)\n",
" fake_images = (fake_images + 1) / 2 # Rescale to [0, 1]\n",
"\n",
" fig, axs = plt.subplots(1, num_images, figsize=(num_images, 1))\n",
" for i in range(num_images):\n",
" axs[i].imshow(fake_images[i].detach().numpy(), cmap='gray')\n",
" axs[i].axis('off')\n",
"\n",
" plt.show()\n",
"\n",
"show_generated_images(generator)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Beyond VAEs and GANs\n",
"\n",
"Generative AI methods have been advancing rapidly. Note that current state-of-the-art methods for generating images (or videos) from text prompts are not VAEs or GANs. (The original DALL-E used a discrete VAE to compress images into tokens that a transformer then modeled, whereas DALL-E 2 generates images with a diffusion model.) More recent methods use **diffusion models**, which start with a pattern of random noise and then gradually transform it into a coherent image. This transformation occurs in steps, each slightly reducing the randomness and shaping the noise into an image that corresponds to a given text input. For more on advanced generative AI techniques, I recommend taking COMPSCI 670, Computer Vision."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Large Language Models (LLMs)\n",
"\n",
"*Large language models* (LLMs) are parametric models applied to text (or audio). Models like GPT-4 are based on *transformers*, a type of artificial neural network architecture. \n",
"\n",
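"OpenAI has not officially disclosed GPT-4's size. When asked, ChatGPT using GPT-4 has claimed that it has 175 billion tunable parameters, but that figure is the size of GPT-3. The Wikipedia page cites unofficial estimates of 1 trillion to 1.76 trillion parameters. GPT-4 was reportedly trained on about 13 trillion tokens, or roughly 10 trillion words (roughly 50 terabytes of data).\n",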
"\n",
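"Perhaps surprisingly, GPT-4 is trained using supervised learning: given the tokens (words and parts of words) seen so far, it attempts to predict the next token. (This is often called *self-supervised* learning, because the labels come from the text itself rather than from human annotators.) Training such a large network on so much data was estimated to cost roughly \\$60 million to \\$100 million.\n",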
"\n",
"More specifically, the model takes as input the most recent \"tokens\" (parts of words), up to 32,768 tokens for the largest original GPT-4 variant, and predicts what the next token will be. After this supervised training, a form of reinforcement learning from human feedback (RLHF) was used to fine-tune the model to produce more desirable answers to prompts.\n",
"\n",
"While some aspects are beyond the scope of this course (e.g., the details of the transformer architecture and the reported use of a mixture of experts), the core technology underlying GPT-4, and hence ChatGPT, is simply a large parametric model trained using gradient descent on a large amount of data."
]
}
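,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate how next-token prediction turns raw text into a supervised learning problem, here is a toy sketch. It is *not* a transformer and not how GPT-4 is implemented: it assumes a made-up corpus, a word-level vocabulary (real LLMs use subword tokenizers), and a hypothetical bigram model that predicts the next token from only the current token (an embedding followed by a linear layer). Each training pair is (current token, next token), and the model is trained by gradient descent on the cross-entropy between its predicted distribution over the vocabulary and the actual next token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Toy corpus and word-level vocabulary (hypothetical; real LLMs use subword tokenizers)\n",
"text = 'the cat sat on the mat the cat sat on the mat'.split()\n",
"vocab = sorted(set(text))  # ['cat', 'mat', 'on', 'sat', 'the']\n",
"stoi = {w: i for i, w in enumerate(vocab)}\n",
"tokens = torch.tensor([stoi[w] for w in text])\n",
"\n",
"# Supervised pairs: the input is the current token, the label is the token that follows it\n",
"inputs, targets = tokens[:-1], tokens[1:]\n",
"\n",
"# Bigram model: embed the current token, then score every possible next token\n",
"model = nn.Sequential(nn.Embedding(len(vocab), 16), nn.Linear(16, len(vocab)))\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.05)\n",
"\n",
"losses = []\n",
"for step in range(200):\n",
"    logits = model(inputs)                   # shape: (num_pairs, vocab_size)\n",
"    loss = F.cross_entropy(logits, targets)  # next-token prediction loss\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    losses.append(loss.item())\n",
"\n",
"# Predict the most likely next token after 'cat'\n",
"with torch.no_grad():\n",
"    next_id = model(torch.tensor([stoi['cat']])).argmax(dim=-1).item()\n",
"print(f'loss: {losses[0]:.3f} -> {losses[-1]:.3f}; predicted token after cat: {vocab[next_id]}')"
]
}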
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}